package jwt

import (
	"context"
	"encoding/json"
	"errors"
	"net/http"

	"github.com/go-kit/kit/log"
)

// JWTMiddleware is an http.Handler that checks the JWT sent with each request
// and only forwards authenticated requests to the wrapped handler.
type JWTMiddleware struct {
	next   http.Handler // handler invoked once the token is accepted
	logger log.Logger   // receives errors raised while writing rejections
	JWTS   JWTService   // parses the raw token string into claims
}

// NewJWTMiddleware wraps next so that it is only reached by requests carrying
// a token that jwts accepts. Problems writing rejection responses are reported
// to logger.
func NewJWTMiddleware(next http.Handler, logger log.Logger, jwts JWTService) *JWTMiddleware {
	mw := new(JWTMiddleware)
	mw.next = next
	mw.logger = logger
	mw.JWTS = jwts
	return mw
}

// User describes the authenticated caller. ServeHTTP fills it from the parsed
// token claims and stores a *User in the request context under the key "user".
// NOTE(review): Go naming would spell the first field ID; it is left as Id
// because renaming an exported field would break existing callers.
type User struct {
	// Id is copied from the claims' Uid field.
	Id       string
	// Username is copied from the claims' Username field.
	Username string
	// Roles is copied from the claims' Role field.
	Roles    []string
}

// ValidateResponse is the JSON document sent back to a client whose request
// was rejected by the authentication middleware.
type ValidateResponse struct {
	// Error carries the human-readable rejection reason.
	Error string `json:"error"`
}

// NewValidateResponse encodes text as a ValidateResponse, producing for
// example {"error":"Please login"}. Any encoding error from json.Marshal is
// returned unchanged together with a nil slice.
func NewValidateResponse(text string) ([]byte, error) {
	return json.Marshal(ValidateResponse{Error: text})
}

// ServeHTTP authenticates r using the JWT found in its "token" request header.
//
// If the header is missing, the token cannot be parsed, or parsing yields no
// claims, a JSON body of the form {"error": "..."} is written with a 401 or
// 403 status and the wrapped handler is NOT called. Otherwise a *User built
// from the claims is stored in the request context under the key "user" and
// the request is passed on to the wrapped handler.
//
// NOTE(review): the status mapping (403 for a missing/expired token, 401 for
// other parse failures) looks inverted relative to the usual HTTP meaning of
// 401/403, but it is preserved because clients may depend on it.
func (m *JWTMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	tokenString := r.Header.Get("token")
	if tokenString == "" {
		m.writeAuthError(w, http.StatusForbidden, "Please login", "write login response error:")
		return
	}

	claims, err := m.JWTS.ParseToken(tokenString)
	if err != nil {
		// errors.Is also matches the sentinels when ParseToken wraps them.
		switch {
		case errors.Is(err, TokenExpired):
			// Message: "token has expired".
			m.writeAuthError(w, http.StatusForbidden, "token已过期", "write token expired error:")
		case errors.Is(err, TokenMalformed):
			// Message: "token format is invalid".
			m.writeAuthError(w, http.StatusForbidden, "token格式错误", "write token format error:")
		default:
			// Message: "sorry, you currently have no permission to access this content".
			m.writeAuthError(w, http.StatusUnauthorized, "非常遗憾,您暂时没有权限访问此内容", "write token access error:")
		}
		return
	}

	if claims == nil {
		// Message: "token is invalid". Returning here is essential: the code
		// below dereferences claims, and the next handler must not run after
		// an error response has been written.
		m.writeAuthError(w, http.StatusForbidden, "token无效", "write token claims error:")
		return
	}

	user := &User{
		Id:       claims.Uid,
		Username: claims.Username,
		Roles:    claims.Role,
	}
	// NOTE(review): a built-in string context key can collide with keys set by
	// other packages (staticcheck SA1029). It is kept as "user" because
	// downstream handlers presumably read the value back with this same key.
	ctx := context.WithValue(r.Context(), "user", user)
	m.next.ServeHTTP(w, r.WithContext(ctx))
}

// writeAuthError sends a rejection: it writes status and, when encoding
// succeeds, a JSON ValidateResponse body carrying text. A failure to encode or
// write the body is reported to m.logger; writeLogMsg is the log message used
// for write failures.
func (m *JWTMiddleware) writeAuthError(w http.ResponseWriter, status int, text, writeLogMsg string) {
	body, err := NewValidateResponse(text)
	if err != nil {
		// Still reject the request with the right status, just without a body.
		_ = m.logger.Log("marshal auth error response error:", err)
		w.WriteHeader(status)
		return
	}
	// Headers must be set before WriteHeader; without this the JSON body would
	// be content-sniffed as text/plain.
	w.Header().Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)
	if _, err := w.Write(body); err != nil {
		_ = m.logger.Log(writeLogMsg, err)
	}
}
